import torch
import numpy as np


def get_state_statistics():
    """Return the hard-coded per-feature normalization statistics.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(mean, std)``, each a 1-D
        ``float32`` tensor with 75 entries built from the literal tables
        below. Entries 72 and 73 have mean ``0.0`` and std ``0.70711``.
        New tensors are created on every call.
    """
    # TODO: other maps have different means and stds; decide how to handle that.
    mean_values = [
        2.09876, 2.09931, 2.09984, 2.10028, 2.10045, 2.10068, 2.10067, 2.10062, 2.10094, 2.10138, 2.10175, 2.10205,
        2.10238, 2.10256, 2.10265, 2.10299, 2.10344, 2.10406, 2.10491, 2.10589, 2.10681, 2.10741, 2.10786, 2.10859,
        2.10953, 2.11053, 2.11154, 2.11251, 2.11323, 2.11385, 2.11462, 2.11549, 2.11645, 2.11718, 2.11749, 2.1176,
        2.11777, 2.11766, 2.11773, 2.11837, 2.1188, 2.11859, 2.1181, 2.11772, 2.1176, 2.11801, 2.11861, 2.11923,
        2.11948, 2.11946, 2.11928, 2.11924, 2.11959, 2.12011, 2.12047, 2.12056, 2.12052, 2.12039, 2.12043, 2.12068,
        2.12116, 2.12141, 2.1214, 2.12144, 2.12143, 2.12166, 2.12198, 2.12237, 2.12265, 2.12278, 2.12294,
        2.12288, 0.0, 0.0, 2.30986,
    ]
    std_values = [
        1.63815, 1.63882, 1.64022, 1.64186, 1.64313, 1.64411, 1.64401, 1.64434, 1.64557, 1.64721, 1.64877, 1.64977,
        1.65044, 1.65012, 1.64996, 1.65026, 1.65162, 1.65289, 1.65421, 1.65535, 1.65579, 1.6553, 1.65472, 1.65601,
        1.65793, 1.65879, 1.6588, 1.65836, 1.65719, 1.65635, 1.65707, 1.65773, 1.6583, 1.65732, 1.65581, 1.65438,
        1.65376, 1.65345, 1.6545, 1.6566, 1.65705, 1.65489, 1.65249, 1.65084, 1.65003, 1.65158, 1.65339, 1.65523,
        1.65525, 1.65451, 1.65341, 1.65292, 1.65404, 1.65531, 1.65629, 1.65594, 1.65482, 1.65363, 1.65295, 1.65293,
        1.65371, 1.654, 1.65289, 1.6515, 1.65009, 1.64942, 1.64911, 1.64904, 1.64909, 1.64876, 1.64799, 1.6474,
        0.70711, 0.70711, 1.31035,
    ]

    # np.array produces float64; .float() then narrows each tensor to float32.
    mean, std = (
        torch.from_numpy(np.array(values)).float()
        for values in (mean_values, std_values)
    )
    return mean, std


def apply_normalization(x, mean, std):
    """Standardize ``x[0]`` in place as ``(x[0] - mean) / std``.

    Only the entry at index 0 of the first dimension is modified; any
    ``x[1:]`` entries are left untouched.
    NOTE(review): presumably ``x`` is ``(1, seq_len, num_features)`` with a
    batch size of 1 — confirm whether other batch entries should also be
    normalized.

    Args:
        x: Tensor modified in place. ``x[0]`` must be able to absorb the
            broadcast of ``mean`` and ``std`` against its trailing dimension(s).
        mean: Tensor subtracted from every row of ``x[0]`` (broadcast).
        std: Tensor that every row of ``x[0]`` is divided by (broadcast).

    Returns:
        None. The update goes through ``x.data``, so autograd does not
        record it.
    """
    # A single broadcast op replaces the per-row Python loop; the element-wise
    # arithmetic is identical. ``x.data[0]`` is a view sharing x's storage,
    # so sub_/div_ write straight into ``x`` without autograd tracking.
    x.data[0].sub_(mean).div_(std)


def apply_normalization_icm(x, mean, std):
    """Standardize every row of ``x`` in place as ``(x - mean) / std``.

    Unlike ``apply_normalization``, this touches all entries along the
    first dimension. NOTE(review): presumably ``x`` is
    ``(batch, num_features)`` — confirm against callers.

    Args:
        x: Tensor modified in place. Each ``x[i]`` must be able to absorb the
            broadcast of ``mean`` and ``std`` against its trailing dimension(s).
        mean: Tensor subtracted from every row (broadcast).
        std: Tensor that every row is divided by (broadcast).

    Returns:
        None. The update goes through ``x.data``, so autograd does not
        record it.
    """
    # One broadcast op over the whole tensor instead of a Python loop over
    # rows; per-element arithmetic is unchanged. ``x.data`` shares storage
    # with ``x``, so the in-place ops mutate ``x`` without autograd tracking.
    x.data.sub_(mean).div_(std)
